package com.jfire.socket.socketclient;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import com.jfire.baseutil.collection.ByteBufferPool;
import com.jfire.baseutil.encrypt.RSAUtil;
import com.jfire.socket.socketclient.filter.AesFilter;
import com.jfire.socket.socketclient.listen.GetReadResult;
import com.jfire.socket.socketserver.bus.ActionType;
import com.jfire.socket.socketserver.bus.BusinessCommand;
import com.jfire.socket.socketserver.exception.AuthException;
import com.jfire.socket.socketserver.exception.ConnectErrorException;
import com.jfire.socket.socketserver.exception.LessThanProtocolException;
import com.jfire.socket.socketserver.exception.NotFitProtocolException;
import com.jfire.socket.socketserver.exception.NotIntactDataException;
import com.jfire.socket.socketserver.exception.TimeoutException;
import com.jfire.socket.socketserver.util.CheckReadBuffer;

/**
 * Socket client that completes an RSA handshake before any business traffic.
 * <p>
 * On every (re)connect it sends an RSA-encrypted auth message, verifies the server's
 * signature over that message, then sends a freshly generated session key (also
 * RSA-encrypted), verifies the server's signature over "ok", and finally installs the
 * key on the {@link AesFilter}.
 */
public class AesClient extends Client
{
    /** Seconds to wait for the TCP connect to complete on each attempt. */
    private static final int   CONNECT_TIMEOUT_SECONDS = 5;
    /**
     * Offset of the 4-byte signature-length field in a handshake response. Presumably a
     * 4-byte frame head followed by three 1-byte command fields, mirroring the request
     * layout built in preAuthData — TODO confirm against headFactory.
     */
    private static final int   LENGTH_FIELD_OFFSET     = 7;
    /**
     * Bytes that precede the payload in a frame; presumably LENGTH_FIELD_OFFSET plus the
     * 4-byte length field — TODO confirm against CheckReadBuffer.
     */
    private static final int   FRAME_PREFIX_LENGTH     = 11;
    /** Length in bytes of the generated session key. */
    private static final int   SESSION_KEY_LENGTH      = 16;
    
    private final RSAUtil      rsaUtil;
    /**
     * Pooled buffer reused for every handshake frame. It may be swapped for a larger pooled
     * buffer while reading; it is returned to the pool in close().
     */
    private ByteBuffer         authBuffer = ByteBufferPool.getBuffer(1024);
    private final Charset      charset    = StandardCharsets.UTF_8;
    private final byte[]       okCheck    = "ok".getBytes(charset);
    /** Session keys are secret material, so they must come from a CSPRNG, not java.util.Random. */
    private final SecureRandom random     = new SecureRandom();
    
    /**
     * Creates a client whose handshake is protected with the given RSA public key and whose
     * later traffic goes through an {@link AesFilter}.
     * 
     * @param publicKey the RSA public key, passed to {@code RSAUtil.setPublicKey}; used to
     *            encrypt handshake data and to verify the server's signatures
     * @param getReadResult read-result callback handed to the base {@code Client}
     */
    public AesClient(byte[] publicKey, GetReadResult getReadResult)
    {
        super(getReadResult);
        this.aesFilter = new AesFilter();
        rsaUtil = new RSAUtil();
        rsaUtil.setPublicKey(publicKey);
    }
    
    /**
     * Ensures an authenticated channel is available, reconnecting if needed.
     * <p>
     * Makes at least one attempt and up to {@code connectTrytimes} attempts. Each attempt
     * opens a channel, connects, and runs the handshake. Retries stop early if the thread
     * is interrupted; the interrupt status is preserved.
     * 
     * @param reuse if true and the current channel exists and has not expired, keep it
     * @throws ConnectErrorException if no attempt connected and authenticated successfully
     */
    @Override
    protected void getClientChannelInfo(boolean reuse) throws ConnectErrorException
    {
        if (reuse && clientChannelInfo != null && clientChannelInfo.expired() == false)
        {
            return;
        }
        if (clientChannelInfo != null)
        {
            clientChannelInfo.closeSocket();
        }
        boolean finishConnect = false;
        int tryTime = connectTrytimes;
        AsynchronousSocketChannel socketChannel = null;
        String address = null;
        do
        {
            try
            {
                socketChannel = AsynchronousSocketChannel.open(ChannelGroupSource.getChannelGroup());
                socketChannel.connect(new InetSocketAddress(ip, port)).get(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
                address = socketChannel.getLocalAddress().toString() + "--->" + socketChannel.getRemoteAddress().toString();
                logger.debug("准备开始加密登录服务器");
                initAuth(socketChannel);
                finishConnect = true;
                break;
            }
            catch (Exception e)
            {
                logger.error("登录验证失败", e);
                closeQuietly(socketChannel);
                if (e instanceof InterruptedException)
                {
                    // Restore the flag that was cleared when the exception was thrown.
                    Thread.currentThread().interrupt();
                }
                if (Thread.currentThread().isInterrupted())
                {
                    // Someone asked this thread to stop; do not keep retrying.
                    break;
                }
                tryTime--;
            }
        } while (tryTime > 0);
        if (finishConnect == false)
        {
            throw new ConnectErrorException("加密客户端经过" + connectTrytimes + "次测试，无法完成连接");
        }
        clientChannelInfo = buildChannelInfo(address, socketChannel);
        clientChannelInfo.startReadWait();
    }
    
    /**
     * Closes a channel while ignoring close failures, since the channel is being abandoned.
     * 
     * @param channel channel to close; may be null
     */
    private static void closeQuietly(AsynchronousSocketChannel channel)
    {
        if (channel == null)
        {
            return;
        }
        try
        {
            channel.close();
        }
        catch (IOException ignored)
        {
            // Best effort: the connection is already considered failed.
        }
    }
    
    /**
     * Closes the client and its connection, then returns the handshake buffer to the pool.
     */
    @Override
    public void close()
    {
        super.close();
        ByteBufferPool.returnBuffer(authBuffer);
    }
    
    /**
     * Runs the two-step handshake on a freshly connected channel.
     * <ol>
     * <li>Sends an RSA-encrypted auth message and checks the server's signature over it.</li>
     * <li>Sends an RSA-encrypted random session key and checks the server's signature over
     * "ok". Only after both checks pass is the key installed on the filter.</li>
     * </ol>
     * 
     * @throws InterruptedException if interrupted while waiting for a write
     * @throws ExecutionException if an asynchronous write fails
     * @throws IOException declared for compatibility
     * @throws ConnectErrorException if the response is malformed or the connection fails
     * @throws TimeoutException if a response does not arrive within readTimeout
     * @throws AuthException if either signature check fails
     */
    private void initAuth(AsynchronousSocketChannel socketChannel) throws InterruptedException, ExecutionException, IOException, ConnectErrorException, TimeoutException, AuthException
    {
        // Step 1: login verification.
        String authMsg = preAuthData(authBuffer);
        logger.trace("发送登录信息");
        byte[] okResponse = sendAndGetResponse(authBuffer, socketChannel);
        // A failed signature check aborts the handshake.
        if (rsaUtil.check(authMsg.getBytes(charset), okResponse) == false)
        {
            logger.trace("平台签名验证失败");
            throw new AuthException();
        }
        logger.trace("平台签名验证成功");
        // Step 2: send the session key. sendAndGetResponse may have swapped authBuffer for a
        // larger pooled buffer, so read the field again here.
        byte[] key = preSendKeyData(authBuffer);
        okResponse = sendAndGetResponse(authBuffer, socketChannel);
        if (rsaUtil.check(okCheck, okResponse) == false)
        {
            throw new AuthException();
        }
        aesFilter.setKey(key);
    }
    
    /**
     * Writes the login frame into the buffer and flips it, ready for writing to the channel.
     * <p>
     * The frame is a head, three command bytes, the ciphertext length, and the RSA-encrypted
     * auth message.
     * 
     * @param buffer buffer to fill; cleared first
     * @return the plaintext auth message, which the server is expected to sign
     */
    private String preAuthData(ByteBuffer buffer)
    {
        buffer.clear();
        String authMsg = "jfire" + System.currentTimeMillis();
        byte[] data = rsaUtil.encrypt(authMsg.getBytes(charset));
        headFactory.putHeadInBuffer(buffer, BusinessCommand.AUTH, 0);
        buffer.put(ActionType.CLIENT_SEND).put(BusinessCommand.AUTH).put(BusinessCommand.REQUEST_SUCCESS).putInt(data.length).put(data).flip();
        return authMsg;
    }
    
    /**
     * Sends a complete request frame and returns the signature bytes from the response.
     * 
     * @param request flipped buffer holding the request; it is reused to receive the response
     * @param socketChannel connected channel
     * @return signature bytes carried by the response frame
     * @throws ConnectErrorException if the response is malformed or the connection fails
     * @throws TimeoutException if the response does not arrive within readTimeout
     */
    private byte[] sendAndGetResponse(ByteBuffer request, AsynchronousSocketChannel socketChannel) throws InterruptedException, ExecutionException, TimeoutException, ConnectErrorException
    {
        // A single asynchronous write may transfer only part of the buffer.
        while (request.hasRemaining())
        {
            socketChannel.write(request).get();
        }
        // The server replies with its signature over what we sent. read() keeps the
        // authBuffer field in sync if it has to grow the buffer.
        ByteBuffer response = read(socketChannel, request);
        response.position(LENGTH_FIELD_OFFSET);
        int signLength = response.getInt();
        if (signLength < 0 || signLength > response.remaining())
        {
            throw new ConnectErrorException("invalid signature length in handshake response: " + signLength);
        }
        byte[] signResponse = new byte[signLength];
        response.get(signResponse);
        return signResponse;
    }
    
    /**
     * Generates a random session key, then writes the key-exchange frame into the buffer and
     * flips it.
     * <p>
     * The frame carries the RSA-encrypted key.
     * 
     * @param buffer buffer to fill; cleared first
     * @return the plaintext session key, to be installed once the server acknowledges it
     */
    private byte[] preSendKeyData(ByteBuffer buffer)
    {
        buffer.clear();
        byte[] key = new byte[SESSION_KEY_LENGTH];
        random.nextBytes(key);
        byte[] result = rsaUtil.encrypt(key);
        headFactory.putHeadInBuffer(buffer, BusinessCommand.SENDKEY, 0);
        buffer.put(ActionType.CLIENT_SEND).put(BusinessCommand.SENDKEY).put(BusinessCommand.REQUEST_SUCCESS).putInt(result.length).put(result).flip();
        return key;
    }
    
    /**
     * Reads from the channel, blocking, until one complete frame has been received.
     * <p>
     * The whole read shares a single time budget of {@code readTimeout} milliseconds.
     * 
     * @param socketChannel connected channel
     * @param buffer buffer to read into (the authBuffer field); cleared first
     * @return the buffer holding exactly one complete frame, in read mode with position 0.
     *         This may be a different, larger pooled buffer than the one passed in; the
     *         authBuffer field is updated to match.
     * @throws TimeoutException if no complete frame arrives within readTimeout
     * @throws ConnectErrorException on protocol violations, I/O errors, peer close, or interrupt
     */
    private ByteBuffer read(AsynchronousSocketChannel socketChannel, ByteBuffer buffer) throws TimeoutException, ConnectErrorException
    {
        buffer.clear();
        long remainingTimeout = readTimeout;
        int readResult = -1;
        boolean complete = false;
        do
        {
            try
            {
                long t0 = System.currentTimeMillis();
                readResult = socketChannel.read(buffer).get(remainingTimeout, TimeUnit.MILLISECONDS);
                remainingTimeout -= System.currentTimeMillis() - t0;
                buffer.flip();
                int length = CheckReadBuffer.checkReadBuffer(buffer, headFactory);
                if (length == buffer.remaining() - FRAME_PREFIX_LENGTH)
                {
                    complete = true;
                    break;
                }
                // Incomplete frame: switch back to write mode, keeping the received bytes.
                buffer.compact();
            }
            catch (LessThanProtocolException e)
            {
                // The head itself is not complete yet. The buffer is still flipped, so switch
                // back to write mode; otherwise the next read overwrites the received bytes.
                buffer.compact();
            }
            catch (NotFitProtocolException e)
            {
                throw new ConnectErrorException(e);
            }
            catch (NotIntactDataException e)
            {
                // The frame does not fit: swap in a larger pooled buffer and keep the field
                // pointing at it so later frames and close() use the right buffer.
                // NOTE(review): assumes expandToSize keeps the received bytes and returns a
                // buffer ready for further writes — confirm.
                buffer = ByteBufferPool.expandToSize(e.getNeedSize(), buffer);
                authBuffer = buffer;
            }
            catch (java.util.concurrent.TimeoutException e)
            {
                logger.error("认证读取超时");
                throw new TimeoutException(readTimeout);
            }
            catch (InterruptedException e)
            {
                Thread.currentThread().interrupt();
                throw new ConnectErrorException(e);
            }
            catch (Exception e)
            {
                throw new ConnectErrorException(e);
            }
        } while (remainingTimeout > 0 && readResult != -1);
        if (complete == false)
        {
            if (readResult == -1)
            {
                throw new ConnectErrorException("connection closed by peer before the handshake response was complete");
            }
            logger.error("认证读取超时");
            throw new TimeoutException(readTimeout);
        }
        return buffer;
    }
}
